week 7: multilevel models

multilevel adventures

multilevel tadpoles

data(reedfrogs, package = "rethinking")
d <- reedfrogs
dim(d)
[1] 48  5
d %>% sample_n(10)
   density pred  size surv  propsurv
1       35 pred   big    4 0.1142857
2       10 pred small    7 0.7000000
3       10   no   big   10 1.0000000
4       10 pred   big    7 0.7000000
5       35 pred   big   13 0.3714286
6       10 pred   big    9 0.9000000
7       10 pred small    9 0.9000000
8       35   no   big   34 0.9714286
9       25 pred   big    6 0.2400000
10      25   no   big   23 0.9200000

Let’s start with the unpooled model. Up to this point in the course, this is the model we would have used to estimate survival separately in each tank.

\[\begin{align*} \text{surv}_i &\sim \text{Binomial}(n_i,p_i) \\ \text{logit}(p_i) &= \alpha_{\text{tank}[i]} \\ \alpha_j &\sim \text{Normal}(0, 1.5) \text{ for }j=1,...,48 \end{align*}\]
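As a quick aside (my own check, not part of the lecture): it’s worth seeing what a Normal(0, 1.5) prior on the logit scale implies on the probability scale. Pushing prior draws through the inverse logit in base R:

```r
# draws from the prior for alpha, mapped to the probability scale
set.seed(1)
alpha_prior <- rnorm(1e5, mean = 0, sd = 1.5)
quantile(plogis(alpha_prior), c(.05, .5, .95))
```

The implied distribution is close to flat on (0, 1), which is why a standard deviation around 1.5 is a common weakly informative choice for logit-scale intercepts.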

d$tank <- factor(1:nrow(d))

m1 <- 
  brm(data = d, 
      family = binomial,
      bf(surv | trials(density) ~ 0 + alpha, 
         alpha ~ 0 + tank, 
         nl = TRUE),
      prior = prior(normal(0, 1.5), class = b, nlpar = alpha),
      iter = 2000, warmup = 1000, chains = 4, cores = 4,
      seed = 13,
      file = here("files/models/71.1"))
print(m1)
 Family: binomial 
  Links: mu = logit 
Formula: surv | trials(density) ~ 0 + alpha 
         alpha ~ 0 + tank
   Data: d (Number of observations: 48) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
alpha_tank1      1.72      0.75     0.36     3.28 1.00     6727     2922
alpha_tank2      2.39      0.87     0.85     4.24 1.00     5009     2249
alpha_tank3      0.75      0.63    -0.45     2.06 1.00     5975     2721
alpha_tank4      2.41      0.90     0.84     4.40 1.00     5413     2712
alpha_tank5      1.72      0.78     0.34     3.36 1.00     6346     2332
alpha_tank6      1.73      0.75     0.36     3.39 1.00     5673     2882
alpha_tank7      2.40      0.87     0.87     4.24 1.00     6013     2773
alpha_tank8      1.71      0.76     0.32     3.29 1.00     5652     2920
alpha_tank9     -0.37      0.61    -1.60     0.80 1.00     5643     2887
alpha_tank10     1.70      0.75     0.39     3.31 1.00     6262     2260
alpha_tank11     0.74      0.63    -0.46     2.05 1.00     6269     2696
alpha_tank12     0.39      0.63    -0.82     1.64 1.00     5644     2858
alpha_tank13     0.76      0.66    -0.46     2.12 1.00     5430     2848
alpha_tank14     0.01      0.61    -1.16     1.22 1.00     6416     3038
alpha_tank15     1.72      0.76     0.34     3.34 1.00     5830     2712
alpha_tank16     1.72      0.78     0.36     3.42 1.00     5907     2726
alpha_tank17     2.54      0.67     1.36     4.01 1.00     4795     2497
alpha_tank18     2.14      0.61     1.05     3.44 1.00     6054     2619
alpha_tank19     1.80      0.54     0.83     2.92 1.00     6042     3107
alpha_tank20     3.09      0.79     1.72     4.82 1.00     5402     2768
alpha_tank21     2.15      0.62     1.05     3.49 1.00     5870     2806
alpha_tank22     2.14      0.57     1.12     3.36 1.00     5687     2898
alpha_tank23     2.13      0.59     1.10     3.41 1.00     5338     3008
alpha_tank24     1.55      0.51     0.61     2.60 1.00     6469     3086
alpha_tank25    -1.11      0.45    -2.04    -0.26 1.00     5811     2806
alpha_tank26     0.08      0.38    -0.65     0.81 1.00     5386     3054
alpha_tank27    -1.54      0.49    -2.54    -0.63 1.00     5931     2981
alpha_tank28    -0.55      0.40    -1.36     0.20 1.00     5774     2862
alpha_tank29     0.07      0.40    -0.72     0.84 1.00     5880     2852
alpha_tank30     1.32      0.48     0.44     2.30 1.00     5203     2097
alpha_tank31    -0.72      0.42    -1.55     0.09 1.00     7194     3016
alpha_tank32    -0.39      0.42    -1.23     0.40 1.00     5849     2924
alpha_tank33     2.85      0.67     1.71     4.33 1.00     5415     2328
alpha_tank34     2.47      0.59     1.42     3.76 1.00     5835     2769
alpha_tank35     2.46      0.57     1.46     3.68 1.00     5369     2660
alpha_tank36     1.91      0.49     1.02     2.97 1.00     6166     2748
alpha_tank37     1.91      0.49     1.00     2.93 1.00     6123     2860
alpha_tank38     3.37      0.77     2.05     5.04 1.00     5313     2355
alpha_tank39     2.46      0.58     1.43     3.72 1.00     6008     2835
alpha_tank40     2.16      0.53     1.21     3.32 1.00     5945     2463
alpha_tank41    -1.91      0.49    -2.95    -1.02 1.00     6293     2323
alpha_tank42    -0.63      0.35    -1.32     0.04 1.00     6646     3016
alpha_tank43    -0.51      0.34    -1.19     0.16 1.00     5326     3035
alpha_tank44    -0.39      0.33    -1.05     0.24 1.00     6226     3190
alpha_tank45     0.52      0.35    -0.15     1.22 1.00     7301     2643
alpha_tank46    -0.63      0.35    -1.34     0.04 1.00     5798     2948
alpha_tank47     1.91      0.49     1.03     2.91 1.00     5941     2893
alpha_tank48    -0.07      0.34    -0.74     0.59 1.00     7503     2958

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
# get posterior
post <- as_draws_df(m1)
# transform logit to probability
p1 <- post %>% 
  pivot_longer(starts_with("b_alpha"),
               names_prefix = "b_alpha_tank",
               values_to = "logit") %>% 
  mutate(prob = logistic(logit),
         tank = as.numeric(name)) %>% 
  ggplot( aes (x = tank, y = prob)) +
  stat_gradientinterval(alpha = .3, color="#5e8485") +
  geom_point( aes(x=as.numeric(tank), y=propsurv),
              data=d) +
  labs(title = "Unpooled model")
p1

Now let’s build up the pooled (multilevel) model and see how it compares.

\[\begin{align*} \text{surv}_i &\sim \text{Binomial}(n_i,p_i) \\ \text{logit}(p_i) &= \alpha_{\text{tank}[i]} \\ \alpha_j &\sim \text{Normal}(\bar{\alpha}, \sigma) \\ \bar{\alpha} &\sim \text{Normal}(0, 1.5) \\ \sigma &\sim \text{Exponential}(1) \end{align*}\]
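Before fitting, a quick prior-predictive sketch (base R, my own aside) of what these hyperpriors imply for a single tank’s survival probability:

```r
set.seed(1)
abar  <- rnorm(1e4, 0, 1.5)        # alpha-bar ~ Normal(0, 1.5)
sigma <- rexp(1e4, rate = 1)       # sigma ~ Exponential(1)
alpha <- rnorm(1e4, abar, sigma)   # one tank's intercept, given the hyperparameters
quantile(plogis(alpha), c(.05, .5, .95))
```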

m2 <- 
  brm(data = d, 
      family = binomial,
      surv | trials(density) ~ 1 + (1 | tank),
      prior = c(prior(normal(0, 1.5), class = Intercept),  # alpha bar
                prior(exponential(1), class = sd)),        # sigma
      iter = 5000, warmup = 1000, chains = 4, cores = 4,
      sample_prior = "yes",
      seed = 13,
      file = here("files/models/71.2"))
print(m2)
 Family: binomial 
  Links: mu = logit 
Formula: surv | trials(density) ~ 1 + (1 | tank) 
   Data: d (Number of observations: 48) 
  Draws: 4 chains, each with iter = 5000; warmup = 1000; thin = 1;
         total post-warmup draws = 16000

Multilevel Hyperparameters:
~tank (Number of levels: 48) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     1.62      0.21     1.25     2.08 1.00     3764     7491

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     1.34      0.25     0.85     1.85 1.00     3111     5830

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
posterior_summary(m2) %>% round(2)
                     Estimate Est.Error    Q2.5   Q97.5
b_Intercept              1.34      0.25    0.85    1.85
sd_tank__Intercept       1.62      0.21    1.25    2.08
Intercept                1.34      0.25    0.85    1.85
r_tank[1,Intercept]      0.79      0.89   -0.80    2.66
r_tank[2,Intercept]      1.71      1.09   -0.16    4.09
r_tank[3,Intercept]     -0.35      0.70   -1.64    1.09
r_tank[4,Intercept]      1.73      1.12   -0.18    4.20
r_tank[5,Intercept]      0.79      0.89   -0.79    2.69
r_tank[6,Intercept]      0.80      0.89   -0.79    2.72
r_tank[7,Intercept]      1.72      1.11   -0.17    4.12
r_tank[8,Intercept]      0.79      0.88   -0.77    2.66
r_tank[9,Intercept]     -1.52      0.66   -2.85   -0.22
r_tank[10,Intercept]     0.80      0.89   -0.78    2.69
r_tank[11,Intercept]    -0.35      0.70   -1.66    1.06
r_tank[12,Intercept]    -0.77      0.66   -2.04    0.57
r_tank[13,Intercept]    -0.34      0.71   -1.69    1.10
r_tank[14,Intercept]    -1.14      0.65   -2.44    0.12
r_tank[15,Intercept]     0.79      0.87   -0.75    2.66
r_tank[16,Intercept]     0.79      0.89   -0.77    2.70
r_tank[17,Intercept]     1.57      0.80    0.15    3.31
r_tank[18,Intercept]     1.05      0.68   -0.20    2.51
r_tank[19,Intercept]     0.67      0.62   -0.45    1.96
r_tank[20,Intercept]     2.35      1.04    0.59    4.64
r_tank[21,Intercept]     1.06      0.70   -0.20    2.54
r_tank[22,Intercept]     1.06      0.70   -0.18    2.54
r_tank[23,Intercept]     1.05      0.68   -0.17    2.53
r_tank[24,Intercept]     0.36      0.58   -0.73    1.53
r_tank[25,Intercept]    -2.34      0.51   -3.37   -1.38
r_tank[26,Intercept]    -1.18      0.46   -2.10   -0.27
r_tank[27,Intercept]    -2.77      0.55   -3.90   -1.75
r_tank[28,Intercept]    -1.81      0.47   -2.75   -0.92
r_tank[29,Intercept]    -1.18      0.46   -2.10   -0.28
r_tank[30,Intercept]     0.11      0.54   -0.92    1.21
r_tank[31,Intercept]    -1.98      0.48   -2.94   -1.06
r_tank[32,Intercept]    -1.65      0.47   -2.58   -0.74
r_tank[33,Intercept]     1.85      0.78    0.47    3.54
r_tank[34,Intercept]     1.37      0.68    0.16    2.84
r_tank[35,Intercept]     1.37      0.68    0.16    2.83
r_tank[36,Intercept]     0.72      0.56   -0.33    1.89
r_tank[37,Intercept]     0.72      0.56   -0.33    1.86
r_tank[38,Intercept]     2.56      0.98    0.92    4.71
r_tank[39,Intercept]     1.37      0.68    0.16    2.80
r_tank[40,Intercept]     1.01      0.61   -0.11    2.29
r_tank[41,Intercept]    -3.15      0.54   -4.28   -2.15
r_tank[42,Intercept]    -1.91      0.42   -2.77   -1.10
r_tank[43,Intercept]    -1.79      0.42   -2.64   -0.98
r_tank[44,Intercept]    -1.68      0.43   -2.54   -0.85
r_tank[45,Intercept]    -0.76      0.43   -1.60    0.08
r_tank[46,Intercept]    -1.91      0.43   -2.77   -1.09
r_tank[47,Intercept]     0.72      0.55   -0.30    1.86
r_tank[48,Intercept]    -1.34      0.42   -2.15   -0.52
prior_Intercept          0.00      1.51   -2.98    2.96
prior_sd_tank            0.99      0.98    0.02    3.69
lprior                  -3.36      0.29   -3.99   -2.88
lp__                  -157.11      6.65 -170.89 -144.91

The “intercepts” reported for each tank (the r_tank terms) are actually offsets: each one is the distance of that tank’s intercept from the grand mean.

gather_draws(m2, r_tank[tank, ]) %>% 
  mean_qi() %>% 
  ggplot(aes( y=tank, x=.value )) +
  geom_point() +
  geom_errorbar( aes(xmin=.lower, xmax=.upper ), alpha=.5) +
  geom_vline(xintercept = 0) +
  labs(x="Varying intercepts")
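If you want each tank’s intercept on its original scale rather than as a distance, brms’s coef() adds the population intercept (fixef()) to each tank’s offset (ranef()) for you; a quick check against the model above:

```r
# per-tank intercepts = b_Intercept + r_tank offsets
coef(m2)$tank[1:3, , "Intercept"]
```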
m1 <- add_criterion(m1, "waic")
m2 <- add_criterion(m2, "waic")

w <- loo_compare(m1, m2, criterion = "waic")

print(w, simplify = F)
   elpd_diff se_diff elpd_waic se_elpd_waic p_waic se_p_waic waic   se_waic
m2    0.0       0.0  -100.2       3.7         21.1    0.9     200.5    7.3 
m1   -6.8       1.8  -107.0       2.3         25.3    1.2     214.0    4.7 
# average survival
post_sum <- posterior_summary(m2)
average_surv <- post_sum["b_Intercept", "Estimate"]
p2 <- gather_draws(m2, r_tank[tank, Intercept]) %>% 
  mutate(prob = logistic(.value+average_surv)) %>% 
  ggplot( aes (x = tank, y = prob)) +
  stat_gradientinterval(alpha = .3, color="#5e8485") +
  geom_point( aes(x=as.numeric(tank), y=propsurv),
              data=d ) +
  geom_hline( aes(yintercept = logistic(average_surv)),
              linetype = "dashed")+
  labs(title = "Partial pooling model")

p2

divergent transitions

From McElreath:

Recall that HMC simulates the frictionless flow of a particle on a surface. In any given transition, which is just a single flick of the particle, the total energy at the start should be equal to the total energy at the end. That’s how energy in a closed system works. And in a purely mathematical system, the energy is always conserved correctly. It’s just a fact about the physics.

But in a numerical system, it might not be. Sometimes the total energy is not the same at the end as it was at the start. In these cases, the energy is divergent. How can this happen? It tends to happen when the posterior distribution is very steep in some region of parameter space. Steep changes in probability are hard for a discrete physics simulation to follow. When that happens, the algorithm notices by comparing the energy at the start to the energy at the end. When they don’t match, it indicates numerical problems exploring that part of the posterior distribution.
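In brms, these events are easy to count from the sampler diagnostics; a small sketch, using a fitted model such as m2 above:

```r
np <- nuts_params(m2)  # one row per chain x iteration x diagnostic
sum(np$Value[np$Parameter == "divergent__"])  # total divergent transitions
```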

centered parameterization

In his lecture, McElreath uses CENTERED PARAMETERIZATION to demonstrate divergent transitions. A very simple example:

\[\begin{align*} x &\sim \text{Normal}(0, \text{exp}(\nu)) \\ \nu &\sim \text{Normal}(0, 3) \\ \end{align*}\]

This expression is centered because the prior for \(x\) is defined directly in terms of another parameter, \(\nu\). It’s intuitive, but it can cause a lot of problems for Stan, which is probably why McElreath used it for his example. In short, when there is limited data within our groups or the population variance is small, the parameters \(x\) and \(\nu\) become strongly dependent, and the resulting geometry is challenging for MCMC to sample. (Think of a long, narrow groove, not a bowl, for your Hamiltonian skateboard.)

set.seed(1)
# plot the likelihoods
ps <- seq( from=-4, to=4, length.out=200) # possible parameter values for both x and nu

crossing(nu = ps, x=ps) %>%  #every possible combination of nu and x
  mutate(
    likelihood_nu = dnorm(nu, 0, 3),
    likelihood_x  = dnorm(x, 0, exp(nu)),
    joint_likelihood = likelihood_nu*likelihood_x
  ) %>% 
  ggplot( aes(x=x, y=nu, fill=joint_likelihood) ) +
  geom_raster() + 
  scale_fill_viridis_c() +
  guides(fill = "none")

The way to fix this is by using an uncentered parameterization:

\[\begin{align*} x &= z\times \text{exp}(\nu) \\ z &\sim \text{Normal}(0, 1) \\ \nu &\sim \text{Normal}(0, 3) \\ \end{align*}\]

set.seed(1)
# plot the likelihoods
ps <- seq( from=-4, to=4, length.out=200) # possible parameter values for both x and nu

crossing(nu = ps, z=ps) %>%  # every possible combination of nu and z
  mutate(
    likelihood_nu = dnorm(nu, 0, 3),
    likelihood_z  = dnorm(z, 0, 1),
    joint_likelihood = likelihood_nu*likelihood_z
  ) %>% 
  ggplot( aes(x=z, y=nu, fill=joint_likelihood) ) +
  geom_raster() +
  scale_fill_viridis_c() +
  guides(fill = "none")
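The two parameterizations define the same joint distribution for \(x\) and \(\nu\); only the geometry the sampler sees changes. A quick simulation check (mine, base R):

```r
set.seed(1)
nu <- rnorm(1e5, 0, 3)
x_centered    <- rnorm(1e5, 0, exp(nu))      # x ~ Normal(0, exp(nu))
x_noncentered <- rnorm(1e5, 0, 1) * exp(nu)  # x = z * exp(nu), z ~ Normal(0, 1)
# interior quantiles of the two x's should agree closely
round(quantile(x_centered,    c(.25, .5, .75)), 2)
round(quantile(x_noncentered, c(.25, .5, .75)), 2)
```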

It’s an important point, but the problems with the centered parameterization are so prevalent1 that brms generally doesn’t allow it (with some exceptions). So we can’t recreate the divergent-transition situation that McElreath demonstrates in his lecture.

McElreath describes the problem of fertility in Bangladesh as follows:

\[\begin{align*} C &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha_{D[i]} \\ \alpha_j &\sim \text{Normal}(\bar{\alpha}, \sigma) \\ \bar{\alpha} &\sim \text{Normal}(0, 1) \\ \sigma &\sim \text{Exponential}(1) \\ \end{align*}\]

But to fit this using brms, we’ll rewrite as:

\[\begin{align*} C &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha + \alpha_{D[i]} \\ \alpha &\sim \text{Normal}(0, 1) \\ \alpha_{D[j]} &\sim \text{Normal}(0, \sigma_{D}) \\ \sigma_{D} &\sim \text{Exponential}(1) \end{align*}\]


data(bangladesh, package="rethinking")
d <- bangladesh

m1 <- brm(
  data=d,
  family=bernoulli,
  use.contraception ~ 1 + (1 | district),
  prior = c( prior(normal(0, 1), class = Intercept), # alpha bar
             prior(exponential(1), class = sd)),       # sigma

  chains=4, cores=4, iter=2000, warmup=1000,
  seed = 1,
  file = here("files/data/generated_data/m71.1"))
m1
 Family: bernoulli 
  Links: mu = logit 
Formula: use.contraception ~ 1 + (1 | district) 
   Data: d (Number of observations: 1934) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~district (Number of levels: 60) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.52      0.09     0.37     0.70 1.00     1374     1915

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -0.54      0.09    -0.72    -0.37 1.00     1998     2342

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
gather_draws(m1, b_Intercept, r_district[district, ]) %>% 
  with_groups(c(.variable, district), median_qi, .value)
# A tibble: 61 × 8
# Groups:   .variable, district [61]
   .variable   district  .value .lower  .upper .width .point .interval
   <chr>          <int>   <dbl>  <dbl>   <dbl>  <dbl> <chr>  <chr>    
 1 b_Intercept       NA -0.536  -0.715 -0.369    0.95 median qi       
 2 r_district         1 -0.454  -0.864 -0.0464   0.95 median qi       
 3 r_district         2 -0.0482 -0.757  0.610    0.95 median qi       
 4 r_district         3  0.301  -0.702  1.35     0.95 median qi       
 5 r_district         4  0.343  -0.239  0.964    0.95 median qi       
 6 r_district         5 -0.0297 -0.592  0.510    0.95 median qi       
 7 r_district         6 -0.275  -0.773  0.197    0.95 median qi       
 8 r_district         7 -0.216  -0.945  0.478    0.95 median qi       
 9 r_district         8  0.0236 -0.567  0.603    0.95 median qi       
10 r_district         9 -0.162  -0.866  0.453    0.95 median qi       
# ℹ 51 more rows
gather_draws(m1, b_Intercept, r_district[district, ]) %>% 
  with_groups(c(.variable, district), median_qi, .value) %>% 
  ggplot(aes( x=district, y=.value)) +
  geom_pointinterval( aes(ymin = .lower, ymax = .upper), 
                      alpha=.5) +
  labs(y="District distance from mean") +
  coord_flip()
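To see the districts on the probability scale, add the population intercept to each offset draw by draw and apply the inverse logit; a sketch with tidybayes (assuming m1 from above):

```r
spread_draws(m1, b_Intercept, r_district[district, ]) %>% 
  mutate(p = plogis(b_Intercept + r_district)) %>%  # per-draw probability of use
  group_by(district) %>% 
  median_qi(p)
```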

\[\begin{align*} C &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha + \alpha_{D[i]} + \beta U_i + \beta_{D[i]}U_i \\ \alpha, \beta &\sim \text{Normal}(0, 1) \\ \alpha_{D[j]} &\sim \text{Normal}(0, \sigma_{D}) \\ \beta_{D[j]} &\sim \text{Normal}(0, \tau_{D}) \\ \sigma_{D}, \tau_{D} &\sim \text{Exponential}(1) \\ \end{align*}\]

m2 <- brm(
  data=d,
  family=bernoulli,
  use.contraception ~ 1 + urban + (1 + urban || district),
  prior = c( prior(normal(0, 1), class = Intercept), 
             prior(normal(0, 1), class = b),
             prior(exponential(1), class = sd)),     

  chains=4, cores=4, iter=2000, warmup=1000,
  seed = 1,
  file = here("files/data/generated_data/m71.2"))

Oops, no divergent transitions.

m2
 Family: bernoulli 
  Links: mu = logit 
Formula: use.contraception ~ 1 + urban + (1 + urban || district) 
   Data: d (Number of observations: 1934) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~district (Number of levels: 60) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.48      0.09     0.32     0.67 1.01     1290     2067
sd(urban)         0.55      0.21     0.11     0.96 1.00      860      912

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -0.70      0.09    -0.88    -0.53 1.00     2275     2893
urban         0.63      0.15     0.33     0.92 1.00     2391     2077

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

more about divergent transitions

From Gelman et al (2020)